import torch
import torch.nn as nn
from torch.nn import functional as F
import yaml
import shutil
import math,os
from tqdm import tqdm
from lightly.utils.debug import std_of_l2_normalized
import torch.distributed as dist

# Determines whether a model name refers to a spiking neural network (SNN)
def is_SNN(model_name):
    """Return True when ``model_name`` looks like a spiking-network name.

    The check is a case-sensitive substring match against a fixed list of
    markers; the marker may appear anywhere in the name, not only at the end.
    """
    snn_markers = ("spiking", "Spiking", "sew", "SEW", "snn", "SNN")
    return any(marker in model_name for marker in snn_markers)

def exchange_conv(encoder, args):
    """Adapt a ResNet encoder's stem to small-image datasets, in place.

    When ``args.dataset`` contains "cifar", "CIFAR" or "NCALTECH101" and
    ``args.arch`` contains "resnet", ``encoder.conv1`` is replaced by a
    3x3 / stride-1 convolution with 64 output channels and
    ``encoder.maxpool`` by an identity layer. Any other combination leaves
    the encoder untouched.

    The new convolution takes 3 input channels for a non-spiking
    architecture and 2 for a spiking one (as decided by ``is_SNN``).
    NOTE(review): the 2-channel input presumably targets event-based data
    with two polarity channels -- confirm against the data pipeline.

    Returns:
        None; ``encoder`` is modified in place.
    """
    dataset_markers = ("cifar", "CIFAR", "NCALTECH101")
    is_small_image_dataset = any(marker in args.dataset for marker in dataset_markers)
    if not is_small_image_dataset or "resnet" not in args.arch:
        return

    in_channels = 2 if is_SNN(args.arch) else 3
    encoder.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
    encoder.maxpool = nn.Identity()

class KNN_Classifier:
    """Weighted k-nearest-neighbour classifier over a feature memory bank.

    Batches of (features, targets) are accumulated with ``update_memory``;
    ``predict`` then classifies query features by a softmax-weighted vote
    of their ``k`` most cosine-similar bank entries and reports accuracy.
    """

    def __init__(self, k=200, temperature=0.1):
        # Number of neighbours that vote (clamped to the bank size in predict).
        self.k = k
        # Similarities are divided by this value before the softmax.
        self.temperature = temperature
        # Per-batch lists; concatenated lazily inside predict().
        self.features_bank = []
        self.targets_bank = []

    def update_memory(self, features, targets):
        """Append one batch of features and their targets to the memory bank."""
        self.features_bank.append(features)
        self.targets_bank.append(targets)

    def reset(self):
        """Drop everything stored in the memory bank."""
        self.features_bank = []
        self.targets_bank = []

    def predict(self, features, targets):
        """Classify ``features`` against the memory bank and score the result.

        Args:
            features: 2-D tensor of query features, one row per sample.
            targets: 1-D tensor of ground-truth class indices for the queries.

        Returns:
            0-dim tensor holding the top-1 accuracy in percent.

        Raises:
            RuntimeError: if the memory bank is empty.
        """
        if not self.features_bank:
            # torch.cat would raise RuntimeError as well, with a less helpful message.
            raise RuntimeError(
                "KNN memory bank is empty; call update_memory() before predict()"
            )

        # Concatenate all stored features and targets
        all_features = torch.cat(self.features_bank, 0)
        all_targets = torch.cat(self.targets_bank, 0)

        # L2 normalize so the dot product below is a cosine similarity
        features = F.normalize(features, dim=1)
        all_features = F.normalize(all_features, dim=1)

        # Temperature-scaled similarity between every query and every bank entry
        sim_matrix = torch.mm(features, all_features.t()) / self.temperature

        # Top-k nearest neighbours; k is clamped so a small bank does not fail
        sim_weight, sim_indices = sim_matrix.topk(k=min(self.k, all_features.size(0)), dim=1)

        # Labels of the selected neighbours
        sim_labels = torch.gather(all_targets.expand(features.size(0), -1),
                                  dim=1,
                                  index=sim_indices)

        # Turn the k similarities of each query into voting weights
        sim_weight = F.softmax(sim_weight, dim=1)

        # Cover every class index seen in either the queries or the bank
        num_classes = max(targets.max().item(), all_targets.max().item()) + 1

        # Sum the neighbour weights per class in one vectorised scatter-add
        # instead of a Python loop over every (query, neighbour) pair.
        pred_scores = torch.zeros(features.size(0), num_classes,
                                  dtype=sim_weight.dtype, device=features.device)
        pred_scores.scatter_add_(1, sim_labels.long(), sim_weight)

        # Predicted class = highest accumulated weight
        pred_labels = pred_scores.argmax(dim=1)

        correct = pred_labels.eq(targets).float().sum()
        accuracy = 100 * correct / targets.shape[0]

        return accuracy

def calculate_temporal_feature_similarity(feature):
    """Mean pairwise cosine similarity between the time steps of each sample.

    ``feature`` is a 3-D tensor whose first two axes are swapped before use,
    i.e. it is expected as [T, Batch, Dim]. For every sample a [T, T] matrix
    of cosine similarities between its time steps is built (the diagonal of
    self-similarities is included), and the mean over *all* entries of all
    samples is returned.

    Returns:
        0-dim tensor: one scalar for the whole batch, not a per-sample vector.
    """
    # [T, Batch, Dim] -> [Batch, T, Dim]
    per_sample = feature.permute(1, 0, 2)
    # Unit-length vectors along Dim, so the dot products are cosines
    unit_vectors = F.normalize(per_sample, p=2, dim=2)
    # [Batch, T, Dim] x [Batch, Dim, T] -> [Batch, T, T]
    pairwise_cosine = torch.bmm(unit_vectors, unit_vectors.transpose(1, 2))
    return pairwise_cosine.mean()

# test using a knn monitor
def knn_test(net, memory_data_loader, test_data_loader, args):
    """Evaluate ``net`` with a weighted kNN monitor.

    A feature bank is built from ``memory_data_loader`` (stopping once about
    ``args.knn_max_samples`` samples were collected), gathered across ranks
    when ``torch.distributed`` is initialised, and then used to classify
    every batch of ``test_data_loader`` via ``knn_predict``.

    The network output is averaged over its first axis before being
    L2-normalised along dim 1.
    NOTE(review): the first axis is presumably the time-step axis of a
    [T, Batch, Dim] output -- confirm against the model.

    Args:
        net: model to evaluate; switched to eval mode and left there.
        memory_data_loader: loader whose ``dataset.classes`` defines the
            number of classes; batches may carry a list of views, of which
            only the first is used.
        test_data_loader: loader yielding ``(data, target)`` batches.
        args: needs ``knn_max_samples``, ``knn_k`` and ``knn_t``.

    Returns:
        Tuple ``(top1_percent, top5_percent, feature_std_avg,
        temporal_feature_similarity_avg)``.
    """
    net.eval()
    num_classes = len(memory_data_loader.dataset.classes)
    top1_hits, top5_hits, num_seen = 0.0, 0.0, 0
    std_meter = AverageMeter('feature_std')
    similarity_meter = AverageMeter('temporal_feature_similarity')

    with torch.no_grad():
        # ---- build the local feature bank ----
        bank_chunks, label_chunks = [], []
        for data, target in tqdm(memory_data_loader, desc='Feature extracting'):
            if isinstance(data, list):
                data = data[0]
            data = data.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            embedding = F.normalize(net(data).mean(dim=0), dim=1)
            bank_chunks.append(embedding)
            label_chunks.append(target)
            # Sample count is estimated as (#batches so far) * (current batch size)
            if len(bank_chunks) * data.size(0) >= args.knn_max_samples:
                break

        # [N_local, D] features and [N_local] labels
        feature_bank = torch.cat(bank_chunks, dim=0)
        feature_labels = torch.cat(label_chunks, dim=0)
        # Merge the banks of all ranks when running distributed
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            feature_bank = concat_all_gather(feature_bank)
            feature_labels = concat_all_gather(feature_labels)
        # [D, N] layout so that queries can be multiplied directly
        feature_bank = feature_bank.t().contiguous()

        # ---- classify the test set by weighted kNN search ----
        progress = tqdm(test_data_loader)
        for data, target in progress:
            data = data.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            batch_size = data.size(0)

            raw_feature = net(data)
            similarity_meter.update(calculate_temporal_feature_similarity(raw_feature), batch_size)
            feature = F.normalize(raw_feature.mean(dim=0), dim=1)
            feature_std = std_of_l2_normalized(feature)
            ranked_labels = knn_predict(feature, feature_bank, feature_labels,
                                        num_classes, args.knn_k, args.knn_t)

            num_seen += batch_size
            top1_hits += (ranked_labels[:, 0] == target).float().sum().item()
            top5_hits += (ranked_labels[:, :5] == target.unsqueeze(1)).float().sum().item()
            std_meter.update(feature_std, batch_size)
            progress.set_description(
                'knn Acc@1:{:.2f}% Acc@5:{:.2f}% feature_std:{:.4f} temporal_feature_similarity:{:.4f}'.format(
                    top1_hits / num_seen * 100,
                    top5_hits / num_seen * 100,
                    std_meter.avg,
                    similarity_meter.avg,
                )
            )

    top1_acc = top1_hits / num_seen * 100
    top5_acc = top5_hits / num_seen * 100
    return top1_acc, top5_acc, std_meter.avg, similarity_meter.avg

# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
# implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
    """Rank classes for each query by a similarity-weighted kNN vote.

    Args:
        feature: [B, D] query features.
        feature_bank: [D, N] bank features (already transposed for matmul).
        feature_labels: [N] class index of every bank entry.
        classes: total number of classes C.
        knn_k: number of neighbours that vote; clamped to N when larger.
        knn_t: temperature; each neighbour votes with ``exp(sim / knn_t)``.

    Returns:
        [B, C] tensor of class indices sorted from most to least likely;
        column 0 is the top-1 prediction.
    """
    # similarity between each query and every bank entry ---> [B, N]
    sim_matrix = torch.mm(feature, feature_bank)
    # topk raises when k exceeds the bank size, so never ask for more than N
    knn_k = min(knn_k, sim_matrix.size(-1))
    # [B, K]
    sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
    # [B, K]
    sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices)
    sim_weight = (sim_weight / knn_t).exp()

    # one-hot encode the label of every selected neighbour ---> [B*K, C]
    one_hot_label = torch.zeros(feature.size(0) * knn_k, classes, device=sim_labels.device)
    one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
    # weighted score ---> [B, C]
    pred_scores = torch.sum(one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1)

    pred_labels = pred_scores.argsort(dim=-1, descending=True)
    return pred_labels


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """Gather ``tensor`` from every rank and concatenate along dim 0.

    Requires an initialised ``torch.distributed`` process group. The
    function runs under ``torch.no_grad()``, so no gradient flows through
    the gathered result.

    Returns:
        The tensors of all ranks concatenated along dim 0, in the order
        filled in by ``all_gather``.
    """
    world_size = torch.distributed.get_world_size()
    # One receive buffer per rank, shaped like the local tensor
    gathered = [torch.ones_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensor, async_op=False)
    return torch.cat(gathered, dim=0)

def _read_yaml_config(path):
    """Return the mapping stored in the YAML file at ``path``.

    ``None`` as path, or a file without any YAML content, yields ``{}``.
    """
    if path is None:
        return {}
    # Context manager so the file handle is closed deterministically
    with open(path, "r") as stream:
        loaded = yaml.safe_load(stream)
    # safe_load returns None for an empty document
    return {} if loaded is None else loaded


def load_config(args):
    """Override attributes of ``args`` with values from YAML config files.

    The files named by ``args.dataset_config``, ``args.method_config`` and
    ``args.model_config`` (each may be ``None``) are all read first and then
    applied in that order, so a later file wins when a key is repeated.

    Args:
        args: namespace-like object; every key found in a config file must
            already exist on it as an attribute.

    Returns:
        The same ``args`` object, updated in place.

    Raises:
        AttributeError: if a config file contains a key that ``args`` lacks.
    """
    sources = (
        ("dataset", args.dataset_config),
        ("method", args.method_config),
        ("model", args.model_config),
    )
    # Read every file before touching args, so a config cannot redirect
    # which files get loaded.
    loaded_configs = [(label, _read_yaml_config(path)) for label, path in sources]

    for label, config in loaded_configs:
        for key, value in config.items():
            if not hasattr(args, key):
                raise AttributeError(f"Unknown {label} configuration parameter: '{key}'")
            setattr(args, key, value)

    return args


def save_checkpoint(state, filename="checkpoint.pth.tar"):
    """Serialise ``state`` to ``filename`` with ``torch.save``.

    Missing parent directories are created first. A bare file name such as
    the default has no directory component, so nothing is created and the
    file lands in the current working directory.
    """
    directory = os.path.dirname(filename)
    # os.makedirs("") raises FileNotFoundError, so skip it for bare file names
    if directory:
        os.makedirs(directory, exist_ok=True)
    torch.save(state, filename)



class AverageMeter:
    """Tracks the most recent value and the running weighted mean of a metric."""

    def __init__(self, name, fmt=":f"):
        # name: label used when printing; fmt: format spec applied to val and avg
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Clear the latest value, the weighted sum, the count and the mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted with weight ``n``."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Build e.g. "{name} {val:f} ({avg:f})" and fill it from the attributes
        template = "{{name}} {{val{0}}} ({{avg{0}}})".format(self.fmt)
        return template.format(**vars(self))


class ProgressMeter:
    """Prints a ``[batch/total]`` counter followed by a list of meters."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, the batch counter and every meter, tab-separated."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print("\t".join([header, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the running batch index to the width of the total, e.g. "[{:3d}/100]"
        width = len(str(num_batches // 1))
        index_spec = f"{{:{width}d}}"
        return f"[{index_spec}/{index_spec.format(num_batches)}]"


def adjust_learning_rate(optimizer, epoch, args):
    """Set the learning rate of every param group for ``epoch``.

    Schedule, starting from ``args.init_lr``:
      * linear warm-up while ``epoch < args.warmup_epochs``;
      * afterwards a half-cosine decay over the remaining epochs when
        ``args.cos`` is truthy;
      * otherwise a step decay of x0.1 for each milestone in
        ``args.schedule`` that ``epoch`` has reached.

    When ``args.fc_lr_scale`` differs from 1.0, groups whose ``"name"``
    contains "fc" get the rate multiplied by that scale; the ``"name"`` key
    is only read in that case.
    """
    lr = args.init_lr

    if epoch < args.warmup_epochs:
        lr = lr * (epoch + 1) / args.warmup_epochs
    elif args.cos:
        # Fraction of the post-warm-up schedule that has elapsed
        progress = (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)
        lr *= 0.5 * (1.0 + math.cos(math.pi * progress))
    else:
        for milestone in args.schedule:
            decay = 0.1 if epoch >= milestone else 1.0
            lr *= decay

    for group in optimizer.param_groups:
        use_fc_scale = args.fc_lr_scale != 1.0 and "fc" in group["name"]
        group["lr"] = lr * args.fc_lr_scale if use_fc_scale else lr
 

def accuracy(output, target, topk=(1,)):
    """Top-k accuracy in percent for each k in ``topk``.

    Args:
        output: [batch, classes] score tensor, or ``None``.
        target: [batch] tensor of ground-truth class indices.
        topk: the k values to evaluate.

    Returns:
        One 1-element tensor per k. When ``output`` is ``None`` a list of
        ``[0]`` placeholders (one per k) is returned instead.
    """
    if output is None:
        return [[0]] * len(topk)

    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)

        # Indices of the largest_k best classes per sample, then -> [largest_k, batch]
        _, top_indices = output.topk(largest_k, 1, True, True)
        top_indices = top_indices.t()
        hits = top_indices.eq(target.view(1, -1).expand_as(top_indices))

        results = []
        for k in topk:
            # A sample counts as correct if the target is within its first k rows
            hits_within_k = hits[:k].contiguous().view(-1).float().sum(0, keepdim=True)
            results.append(hits_within_k.mul_(100.0 / num_samples))
        return results


def sanity_check(state_dict, pretrained_weights):
    """Assert that linear-probe training left the backbone weights untouched.

    Every entry of ``state_dict`` except those whose key contains
    "fc.weight" or "fc.bias" is compared element-wise with the entry of the
    pretrained checkpoint stored under ``"module.encoder_q." + <key without
    a leading "module.">``.

    Args:
        state_dict: state dict of the model after linear-classifier training.
        pretrained_weights: path of the pretrained checkpoint; it must hold
            a ``"state_dict"`` entry.

    Raises:
        AssertionError: if any compared tensor differs from the checkpoint.
    """
    print("=> loading '{}' for sanity check".format(pretrained_weights))
    # NOTE: weights_only=False unpickles arbitrary objects -- only load trusted files.
    checkpoint = torch.load(pretrained_weights, map_location="cpu", weights_only=False)
    pretrained_state = checkpoint["state_dict"]

    wrapper_prefix = "module."
    for name in list(state_dict.keys()):
        # The linear head is the only part that is allowed to change
        if "fc.weight" in name or "fc.bias" in name:
            continue

        # Map the key to its name inside the pretrained model
        bare_name = name[len(wrapper_prefix):] if name.startswith(wrapper_prefix) else name
        pretrained_name = "module.encoder_q." + bare_name

        unchanged = (state_dict[name].cpu() == pretrained_state[pretrained_name]).all()
        assert unchanged, "{} is changed in linear classifier training.".format(name)

    print("=> sanity check passed.")